#!/usr/bin/env python3
# C8 (REAL v2) — L1/L2/L3 Relation Maps with present-act control, t=0 bootstrap, and rail gating
# Control is boolean/ordinal only (no curve weights, no RNG in control).
# Eligibility(x,y,t) := cooldown==0 AND duty-phase AND rail AND (support>=K OR first_active_t0_for_band_with_K)
# Rail gate: alt2 parity ((x+y)%2) toggled each tick to sustain neighbor support deterministically.
# From commit maps we build:
#   L1 (B): co-future frontiers & representative paths
#   L2 (C): coherent path clusters (overlap >= tau_coh)
#   L3 (U): unify links (same end OR overlap >= tau_unify)

import argparse, csv, hashlib, json, math, sys
from pathlib import Path

# ---------- utils ----------
def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents.

    An already existing directory is left untouched (no error is raised).
    """
    p.mkdir(exist_ok=True, parents=True)
def sha256_of_file(p: Path):
    """Return the hexadecimal SHA-256 digest of the file at ``p``.

    The file is opened in binary mode and consumed in 1 MiB pieces so that
    the whole content never has to be held in memory at once.
    """
    piece_size = 1 << 20
    digest = hashlib.sha256()
    with p.open('rb') as stream:
        while True:
            piece = stream.read(piece_size)
            if piece == b'':
                break
            digest.update(piece)
    return digest.hexdigest()
def write_json(p: Path, obj):
    """Serialize ``obj`` as 2-space-indented JSON into ``p`` (UTF-8).

    The parent directory of ``p`` is created first if it does not exist.
    """
    ensure_dir(p.parent)
    payload = json.dumps(obj, indent=2)
    p.write_text(payload, encoding='utf-8')
def write_csv(p: Path, header, rows):
    """Write a CSV file at ``p``: one ``header`` row followed by ``rows``.

    The parent directory of ``p`` is created first if it does not exist.
    The file is written as UTF-8 with ``newline=''`` so the csv module
    controls line endings itself.
    """
    ensure_dir(p.parent)
    with p.open('w', newline='', encoding='utf-8') as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        for record in rows:
            writer.writerow(record)
def load_json(p: Path):
    """Parse and return the JSON document stored at ``p`` (read as UTF-8).

    Raises:
        FileNotFoundError: if ``p`` does not exist, with the message
            ``Missing file: <path>``.
    """
    if p.exists():
        text = p.read_text(encoding='utf-8')
        return json.loads(text)
    raise FileNotFoundError(f"Missing file: {p}")

# ---------- geometry / bands ----------
def ring_bounds(nx, ny, outer_margin, frac_bounds):
    """Convert fractional band limits into absolute radii.

    The effective radius is half of the shorter grid side minus
    ``outer_margin``; when that is not positive, a quarter of the longer
    grid side is used instead.

    Returns:
        ``(bounds, R_eff)`` where ``bounds`` is a list of ``(r0, r1)`` tuples,
        one per ``(f0, f1)`` pair in ``frac_bounds``, each scaled by ``R_eff``.
    """
    radius = min(nx, ny) / 2.0 - outer_margin
    if radius <= 0:
        # Margin swallowed the whole disc: fall back to a size-based radius.
        radius = max(nx, ny) / 4.0
    scaled = []
    for inner_frac, outer_frac in frac_bounds:
        scaled.append((inner_frac * radius, outer_frac * radius))
    return scaled, radius

def band_id_for_cell(xc, yc, cx, cy, r_bounds):
    """Return the index of the radial band containing point ``(xc, yc)``.

    The distance from ``(cx, cy)`` is compared against each half-open
    interval ``[r0, r1)`` of ``r_bounds`` in order; the first match wins.
    ``-1`` is returned when no interval contains the distance.
    """
    dist = math.hypot(xc - cx, yc - cy)
    matching = (idx for idx, (lo, hi) in enumerate(r_bounds) if lo <= dist < hi)
    return next(matching, -1)

# ---------- rail gate ----------
def rail_allows(ix, iy, t, mode, phase):
    """Tell whether the rail gate lets cell ``(ix, iy)`` act at tick ``t``.

    Only ``mode == "alt2"`` gates anything: the cell is allowed when the
    parity of ``t`` equals the checkerboard parity ``(ix + iy) & 1`` XOR the
    low bit of ``phase``. Every other mode (including ``"none"``) always
    allows.
    """
    if mode != "alt2":
        return True
    wanted_tick_parity = ((ix + iy) & 1) ^ (phase & 1)
    return (t & 1) == wanted_tick_parity

# ---------- engine ----------
def build_commit_maps(nx, ny, H, r_bounds, controls):
    """Simulate ``H`` ticks of the gated commit process on an ``nx`` x ``ny`` grid.

    Each cell is assigned a radial band (``-1`` = outside every band) and the
    band selects one entry of ``controls`` (keys ``period``, ``duty``,
    ``neighbor_threshold``, ``cooldown_steps`` and the optional ``rail_mode`` /
    ``rail_phase``). At tick ``t`` a banded cell commits when all of these hold:

    * its cooldown counter is 0,
    * ``t % period < duty``,
    * the rail gate allows it,
    * it has at least ``neighbor_threshold`` committed 4-neighbours at
      ``t-1`` -- except at ``t == 0`` for bands with a positive threshold,
      where the support test is skipped (bootstrap).

    Returns:
        ``(commit_maps, band)``: ``commit_maps[t][iy][ix]`` is 1 where a cell
        committed at tick ``t`` (else 0) and ``band[iy][ix]`` is the band id.
    """
    center_x = (nx - 1) / 2.0
    center_y = (ny - 1) / 2.0
    # NOTE(review): cells are sampled at (ix+0.5, iy+0.5) while the center is
    # (n-1)/2, so the two conventions differ by half a cell -- confirm intended.
    band = [
        [band_id_for_cell(col + 0.5, row + 0.5, center_x, center_y, r_bounds)
         for col in range(nx)]
        for row in range(ny)
    ]

    # Per-band control tables, indexed by band id.
    period  = [c["period"] for c in controls]
    duty    = [c["duty"] for c in controls]
    kthresh = [c["neighbor_threshold"] for c in controls]
    cool    = [c["cooldown_steps"] for c in controls]
    rmode   = [c.get("rail_mode", "alt2") for c in controls]
    rphase  = [c.get("rail_phase", 0) for c in controls]

    four_neighbors = ((1, 0), (-1, 0), (0, 1), (0, -1))
    previous = [[0] * nx for _ in range(ny)]
    cooldown = [[0] * nx for _ in range(ny)]
    commit_maps = []

    for t in range(H):
        current = [[0] * nx for _ in range(ny)]
        for iy, band_row in enumerate(band):
            cooldown_row = cooldown[iy]
            current_row = current[iy]
            for ix, b in enumerate(band_row):
                if b < 0:
                    continue

                fires = False
                if cooldown_row[ix] == 0:
                    in_duty = (t % period[b]) < duty[b]
                    on_rail = rail_allows(ix, iy, t, rmode[b], rphase[b])
                    if in_duty and on_rail:
                        threshold = kthresh[b]
                        if threshold > 0 and t == 0:
                            # t=0 bootstrap: no previous tick exists to
                            # provide support, so K>0 bands fire unsupported.
                            fires = True
                        else:
                            # Count committed 4-neighbours from tick t-1.
                            support = 0
                            for dx, dy in four_neighbors:
                                jx, jy = ix + dx, iy + dy
                                if 0 <= jx < nx and 0 <= jy < ny:
                                    support += previous[jy][jx]
                            fires = support >= threshold

                if fires:
                    current_row[ix] = 1
                    cooldown_row[ix] = cool[b]
                elif cooldown_row[ix] > 0:
                    cooldown_row[ix] -= 1

        commit_maps.append(current)
        previous = current
    return commit_maps, band

# ---------- branching (L1) ----------
def neighbors_3fan(nx, ny, x, y):
    """List the forward "3-fan" of ``(x, y)``: the cells on row ``y + 1``.

    Candidates are returned left to right: ``(x-1, y+1)`` when ``x-1 >= 0``,
    always ``(x, y+1)``, and ``(x+1, y+1)`` when ``x+1 < nx``. The result is
    empty when row ``y + 1`` lies outside the grid.
    """
    row = y + 1
    if row >= ny:
        return []
    left = [(x - 1, row)] if x - 1 >= 0 else []
    right = [(x + 1, row)] if x + 1 < nx else []
    return left + [(x, row)] + right

def build_branches_from_seeds(commit_maps, seeds, H_alt, max_cand_per_step, max_end_branches):
    """Grow forward frontiers from each seed through the commit maps (L1).

    Starting from the seed (clamped into the grid), the frontier at depth
    ``d`` is expanded with ``neighbors_3fan`` and a neighbour is kept only if
    ``commit_maps[d]`` marks it as committed. Each new frontier is capped at
    the first ``max_cand_per_step`` cells in sorted order. Expansion runs for
    ``min(H_alt, len(commit_maps))`` steps or until the frontier is empty.

    Returns:
        ``(B_counts, end_paths)`` where ``B_counts`` holds
        ``(seed_id, depth, n_candidates)`` rows and ``end_paths`` holds
        ``(seed_id, path)`` pairs; ``path`` is a list of ``(x, y)`` cells from
        the seed to one of the (at most ``max_end_branches``) final frontier
        cells. A seed whose frontier died out contributes no end paths.
    """
    first_map = commit_maps[0]
    ny, nx = len(first_map), len(first_map[0])
    depths = min(H_alt, len(commit_maps))

    B_counts = []   # (seed_id, depth, n_candidates)
    end_paths = []  # (seed_id, path[(x,y)...])

    for sid, (seed_x, seed_y) in enumerate(seeds):
        start = (min(max(0, seed_x), nx - 1), min(max(0, seed_y), ny - 1))
        frontier = {start}
        B_counts.append((sid, 0, len(frontier)))
        parent = {}      # (x, y, depth) -> (px, py, depth - 1); first finder wins
        last_depth = 0   # depth of the most recently recorded frontier

        for d in range(depths):
            allowed = commit_maps[d]
            reached = set()
            for (x, y) in sorted(frontier):
                for (cand_x, cand_y) in neighbors_3fan(nx, ny, x, y):
                    if allowed[cand_y][cand_x] != 1:
                        continue
                    reached.add((cand_x, cand_y))
                    parent.setdefault((cand_x, cand_y, d + 1), (x, y, d))
            if len(reached) > max_cand_per_step:
                reached = set(sorted(reached)[:max_cand_per_step])
            frontier = reached
            last_depth = d + 1
            B_counts.append((sid, last_depth, len(frontier)))
            if not frontier:
                break

        # Representative paths: walk the parent links back to the seed.
        for end_cell in sorted(frontier)[:max_end_branches]:
            trail = [end_cell]
            node = (end_cell[0], end_cell[1], last_depth)
            while node[2] > 0:
                previous = parent.get(node)
                if previous is None:
                    break
                trail.append((previous[0], previous[1]))
                node = previous
            trail.reverse()
            end_paths.append((sid, trail))
    return B_counts, end_paths

# ---------- L2 (coherent sets) & L3 (unify) ----------
def path_overlap_ratio(pa, pb):
    """Fraction of aligned positions at which two paths hold the same cell.

    Only the first ``min(len(pa), len(pb))`` positions are compared and the
    match count is divided by that length. Returns ``0.0`` when either path
    is empty.
    """
    span = min(len(pa), len(pb))
    if span == 0:
        return 0.0
    matches = sum(1 for cell_a, cell_b in zip(pa, pb) if cell_a == cell_b)
    return matches / float(span)

class UF:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    ``p`` holds parent links and ``sz`` the size of the tree rooted at each
    index (only meaningful for roots).
    """

    def __init__(self, n):
        self.p = [i for i in range(n)]
        self.sz = [1 for _ in range(n)]

    def f(self, a):
        """Return the root of ``a``, shortening the visited chain on the way."""
        links = self.p
        while links[a] != a:
            links[a] = links[links[a]]  # point at the grandparent
            a = links[a]
        return a

    def u(self, a, b):
        """Merge the sets of ``a`` and ``b``; return ``False`` if already joined."""
        root_a = self.f(a)
        root_b = self.f(b)
        if root_a == root_b:
            return False
        # Attach the smaller tree beneath the larger one.
        if self.sz[root_a] < self.sz[root_b]:
            root_a, root_b = root_b, root_a
        self.p[root_b] = root_a
        self.sz[root_a] += self.sz[root_b]
        return True

def build_L2_sets(end_paths, tau_coh):
    """Cluster end paths into coherent sets (L2).

    Two paths are linked when ``path_overlap_ratio`` of the pair is at least
    ``tau_coh``; sets are the connected components of those links.

    Returns:
        ``(set_ids, stats)``. ``set_ids[i]`` is the set id of ``end_paths[i]``
        (ids are numbered in order of first appearance). ``stats[k]`` is a dict
        with ``size``, ``mean_overlap`` (mean pairwise overlap inside the set,
        ``1.0`` for a singleton) and ``seeds`` (sorted seed ids in the set).
        Both lists are empty when ``end_paths`` is empty.
    """
    total = len(end_paths)
    if total == 0:
        return [], []

    paths = [entry[1] for entry in end_paths]

    links = UF(total)
    for i in range(total):
        for j in range(i + 1, total):
            if path_overlap_ratio(paths[i], paths[j]) >= tau_coh:
                links.u(i, j)

    # Group path indices by component root, keeping first-seen order.
    members_by_root = {}
    for i in range(total):
        members_by_root.setdefault(links.f(i), []).append(i)

    set_ids = [0] * total
    stats = []
    for new_id, members in enumerate(members_by_root.values()):
        for i in members:
            set_ids[i] = new_id
        # Pairwise overlaps inside the set, in (a, b) index order.
        pair_overlaps = [
            path_overlap_ratio(paths[ia], paths[ib])
            for pos, ia in enumerate(members)
            for ib in members[pos + 1:]
        ]
        seed_ids = {end_paths[i][0] for i in members}
        if pair_overlaps:
            mean_overlap = sum(pair_overlaps) / len(pair_overlaps)
        else:
            mean_overlap = 1.0
        stats.append({"size": len(members),
                      "mean_overlap": mean_overlap,
                      "seeds": sorted(seed_ids)})
    return set_ids, stats

def build_L3_unify(set_ids, end_paths, tau_unify):
    """Link coherent sets into unified components (L3).

    For every pair of sets, all cross pairs of member paths are inspected:
    if any two paths share the same final cell the sets get a ``"tie_end"``
    edge of weight ``1.0``; otherwise, if the best ``path_overlap_ratio`` is
    at least ``tau_unify``, they get an ``"overlap"`` edge weighted by it.

    Returns:
        ``(edges, components)`` where ``edges`` is a list of
        ``(set_a, set_b, link_type, weight)`` and ``components`` is a list of
        lists of set ids connected by those edges. With no end paths the
        result is ``([], [[]])``.
    """
    if len(end_paths) == 0:
        return [], [[ ]]  # trivial

    members = {}
    for path_idx, sid in enumerate(set_ids):
        members.setdefault(sid, []).append(path_idx)
    keys = sorted(members)

    edges = []
    for pos, sa in enumerate(keys):
        for sb in keys[pos + 1:]:
            best = 0.0
            shares_end = False
            for i in members[sa]:
                pa = end_paths[i][1]
                for j in members[sb]:
                    pb = end_paths[j][1]
                    if pa and pb and pa[-1] == pb[-1]:
                        shares_end = True
                    ratio = path_overlap_ratio(pa, pb)
                    if ratio > best:
                        best = ratio
            if shares_end:
                edges.append((sa, sb, "tie_end", 1.0))
            elif best >= tau_unify:
                edges.append((sa, sb, "overlap", float(best)))

    # Without edges every set is its own component.
    if not edges:
        return edges, [[k] for k in keys]

    position = {k: pos for pos, k in enumerate(keys)}
    forest = UF(len(keys))
    for edge in edges:
        forest.u(position[edge[0]], position[edge[1]])
    groups = {}
    for k in keys:
        groups.setdefault(forest.f(position[k]), []).append(k)
    return edges, list(groups.values())

# ---------- main ----------
def main():
    """Run the C8 pipeline end to end from the command line.

    Reads the ``--manifest`` and ``--diag`` JSON files, builds the commit
    maps, derives the L1 branches, L2 coherent sets and L3 unify links, and
    writes the artifacts under ``--out``:

    * ``metrics/cofut_local.csv``, ``metrics/branches_last.csv``,
      ``metrics/coherent_sets.csv``, ``metrics/unify_trace.csv``,
      ``metrics/alt_horizon.json``
    * ``audits/relations_c8.json`` (inputs summary, checks and PASS flag)
    * ``run_info/hashes.json`` (SHA-256 of both input files)

    A one-line JSON summary is printed to stdout at the end.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--manifest', required=True)
    ap.add_argument('--diag', required=True)
    ap.add_argument('--out', required=True)
    args = ap.parse_args()

    # Output layout: three sub-directories created up front.
    out = Path(args.out)
    metr, aud, info = out/'metrics', out/'audits', out/'run_info'
    for d in (metr,aud,info): ensure_dir(d)

    manifest = load_json(Path(args.manifest))
    diag     = load_json(Path(args.diag))

    # Grid size and tick count come from the manifest (defaults: 256x256, 128 ticks).
    nx = int(manifest.get('domain',{}).get('grid',{}).get('nx',256))
    ny = int(manifest.get('domain',{}).get('grid',{}).get('ny',256))
    H  = int(manifest.get('domain',{}).get('ticks',128))

    # Radial band geometry comes from the diag file.
    outer_margin = int(diag.get('ring',{}).get('outer_margin', 8))
    frac_bounds  = diag.get('bands',{}).get('frac_bounds',
                    [[0.00,0.35],[0.35,0.60],[0.60,0.85],[0.85,1.00]])
    # NOTE(review): R_eff is not used again in this function.
    r_bounds, R_eff = ring_bounds(nx, ny, outer_margin, frac_bounds)

    # Controls per band (rails + cooldown + K)
    # The default list has one entry per default band, in band-id order.
    controls = diag.get('controls',{}).get('per_band', [
        {"period":6, "duty":3, "neighbor_threshold":1, "cooldown_steps":1, "rail_mode":"alt2", "rail_phase":0},  # β-2
        {"period":6, "duty":2, "neighbor_threshold":1, "cooldown_steps":1, "rail_mode":"alt2", "rail_phase":0},  # β-1
        {"period":8, "duty":1, "neighbor_threshold":1, "cooldown_steps":1, "rail_mode":"alt2", "rail_phase":0},  # β0
        {"period":10,"duty":1, "neighbor_threshold":0, "cooldown_steps":1, "rail_mode":"alt2", "rail_phase":0}   # β+1
    ])

    # Build commit maps (REAL)
    # The band map returned alongside is discarded here.
    commit_maps, _ = build_commit_maps(nx, ny, H, r_bounds, controls)

    # Seeds & branching params
    seeds_cfg = diag.get('seeds', {})
    mode = seeds_cfg.get('mode', 'grid')         # default to small grid for richer structure
    H_alt = int(diag.get('horizon', {}).get('alt', 64))
    max_cand = int(diag.get('branching', {}).get('max_candidates_per_step', 32))
    max_ends = int(diag.get('branching', {}).get('max_end_branches', 256))

    # make seeds
    seeds=[]
    if mode == 'center':
        seeds=[(nx//2, ny//2)]
    elif mode == 'list':
        # Explicit points; an empty or missing list falls back to the grid center.
        pts = seeds_cfg.get('points', [])
        seeds = [(int(p['x']), int(p['y'])) for p in pts] if pts else [(nx//2,ny//2)]
    else:  # grid (default)
        # Regular lattice with spacing `step`, offset by half a step from the origin.
        # NOTE(review): a grid_step of 0 makes range() raise ValueError.
        step = int(seeds_cfg.get('grid_step', 64))
        for y in range(step//2, ny, step):
            for x in range(step//2, nx, step):
                seeds.append((x,y))

    # L1
    B_counts, end_paths = build_branches_from_seeds(commit_maps, seeds, H_alt, max_cand, max_ends)

    # L2
    tau_coh = float(diag.get('coherence', {}).get('tau_coh', 0.75))
    set_ids, set_stats = build_L2_sets(end_paths, tau_coh)

    # L3
    tau_unify = float(diag.get('unify', {}).get('tau_unify', 0.50))
    unify_edges, comps = build_L3_unify(set_ids, end_paths, tau_unify)

    # -------- write artifacts --------
    write_csv(metr/'cofut_local.csv', ['seed_id','depth','n_candidates'], B_counts)

    # One row per end path; depth is the number of steps (path length minus one).
    rows_end=[]
    for bid,(sid,path) in enumerate(end_paths):
        ex,ey = path[-1] if path else (-1,-1)
        rows_end.append([bid, sid, max(0, len(path)-1), ex, ey, len(path)])
    write_csv(metr/'branches_last.csv', ['branch_id','seed_id','depth','end_x','end_y','path_len'], rows_end)

    # One row per coherent set; the set id is its position in set_stats.
    rows_sets=[]
    for sid_stat,stat in enumerate(set_stats):
        rows_sets.append([sid_stat, stat['size'], round(stat['mean_overlap'],6), len(stat['seeds'])])
    write_csv(metr/'coherent_sets.csv', ['set_id','size','mean_overlap','seeds_covered'], rows_sets)

    rows_unify=[]
    for (u,v,link,w) in unify_edges:
        rows_unify.append([u,v,link, round(w,6)])
    write_csv(metr/'unify_trace.csv', ['u_set','v_set','link_type','weight'], rows_unify)

    write_json(metr/'alt_horizon.json', {"H_alt": H_alt})

    # PASS logic
    # B_ok: at least one seed reached a non-empty frontier beyond depth 0.
    B_ok  = any((d>0 and n>0) for (_,d,n) in B_counts)
    # NOTE(review): build_L2_sets only yields stats when end_paths is non-empty,
    # so the second disjunct alone decides L2_ok; the size>=2 test has no effect.
    L2_ok = (len(set_stats)>0 and any(s['size']>=2 for s in set_stats)) or (len(end_paths)>0)
    U_ok  = (len(comps)>=1)  # components exist (edges may be 0)
    PASS  = bool(B_ok and L2_ok and U_ok)

    write_json((out/'audits'/'relations_c8.json'), {
        "nx": nx, "ny": ny, "H": H, "H_alt": H_alt,
        "seeds": seeds, "n_end_paths": len(end_paths),
        "coherent_sets": len(set_stats),
        "unify_edges": len(unify_edges), "unify_components": len(comps),
        "thresholds": {"tau_coh": tau_coh, "tau_unify": tau_unify},
        "checks": {"B_ok": B_ok, "L2_ok": L2_ok, "U_ok": U_ok},
        "PASS": PASS
    })

    # Record input hashes so a run can be tied to its exact input files.
    write_json((out/'run_info'/'hashes.json'), {
        "manifest_hash": sha256_of_file(Path(args.manifest)),
        "diag_hash":     sha256_of_file(Path(args.diag)),
        "engine_entrypoint": f"python {Path(sys.argv[0]).name} --manifest <...> --diag <...> --out <...>"
    })

    print("C8 SUMMARY (REAL v2):", json.dumps({
        "H_alt": H_alt,
        "n_branches": len(end_paths),
        "n_L2_sets": len(set_stats),
        "n_unify_edges": len(unify_edges),
        "PASS": PASS,
        "audit_path": str((out/'audits'/'relations_c8.json').as_posix())
    }))

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        # Best-effort failure audit. The output directory is recovered straight
        # from sys.argv because main() may have failed before (or while)
        # parsing/using its arguments. Both spellings accepted by argparse are
        # recognised: "--out DIR" and "--out=DIR".
        out_dir = None
        for i, a in enumerate(sys.argv):
            if a == '--out' and i + 1 < len(sys.argv):
                out_dir = Path(sys.argv[i + 1])
                break
            if a.startswith('--out='):
                out_dir = Path(a.split('=', 1)[1])
                break
        if out_dir is not None:
            try:
                ensure_dir(out_dir/'audits')
                write_json(out_dir/'audits'/'relations_c8.json',
                           {"PASS": False, "failure_reason": f"Unexpected error: {type(e).__name__}: {e}"})
            except OSError as audit_err:
                # Writing the audit must never hide the original failure
                # (e.g. when the output location itself is the problem).
                print(f"WARNING: could not write failure audit: {audit_err}", file=sys.stderr)
        # Re-raise the original exception so the process still exits non-zero
        # with the real traceback.
        raise